{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Callbacks\n",
    "\n",
    "> Callbacks which work with a learner's data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CollectDataCallback(Callback):\n",
    "    \"Record every batch, together with its `pred` and `loss`, in `self.data`. Mainly for testing\"\n",
    "    def before_fit(self):\n",
    "        \"Reset `self.data` to an empty `L` at the start of every fit\"\n",
    "        self.data = L()\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Append the detached `(xb, yb, pred, loss)` tuple for the batch that just ran\"\n",
    "        batch_record = (self.xb, self.yb, self.pred, self.loss)\n",
    "        self.data.append(self.learn.to_detach(batch_record))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CudaCallback(Callback):\n",
    "    \"Move the model and each batch to `device` (`default_device()` when `device` is None)\"\n",
    "    def __init__(self, device=None): self.device = ifnone(device, default_device())\n",
    "    def before_batch(self):\n",
    "        \"Put the inputs and targets of the current batch on `self.device`\"\n",
    "        # Pass `self.device` explicitly so batches land on the same device as the model,\n",
    "        # rather than on whatever `to_device` picks by default\n",
    "        self.learn.xb,self.learn.yb = to_device(self.xb, self.device),to_device(self.yb, self.device)\n",
    "    def before_fit(self):\n",
    "        \"Put the model on `self.device` before training starts\"\n",
    "        self.model.to(self.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You don't normally need to use this callback, because fastai's `DataLoader` will handle passing data to a device for you. However, if you already have a plain PyTorch `DataLoader` and can't change it for some reason, you can use this callback instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cuda\n",
    "learn = synth_learner(cbs=CudaCallback)\n",
    "learn.fit(1)\n",
    "# After fitting with `CudaCallback`, the model's parameters are expected to be on the GPU\n",
    "test_eq(next(learn.model.parameters()).device.type, 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but_as=TfmdDL.__init__)\n",
    "@delegates()\n",
    "class WeightedDL(TfmdDL):\n",
    "    \"A `TfmdDL` that, when shuffling, draws item indices with probability proportional to `wgts`\"\n",
    "    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # Without explicit weights every item is equally likely\n",
    "        raw_wgts = array([1.]*len(dataset) if wgts is None else wgts)\n",
    "        # Normalise so the weights sum to 1 and can be passed as `p` to `np.random.choice`\n",
    "        self.wgts = raw_wgts/raw_wgts.sum()\n",
    "\n",
    "    def get_idxs(self):\n",
    "        \"Return `self.n` weighted random indices when shuffling, otherwise the parent's indices\"\n",
    "        if self.n==0: return []\n",
    "        if self.shuffle:\n",
    "            return list(np.random.choice(self.n, self.n, p=self.wgts))\n",
    "        return super().get_idxs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):\n",
    "    \"Call `self.dataloaders` with `WeightedDL`, passing `wgts` to the first subset's loader only\"\n",
    "    first_kwargs = {'wgts':wgts}\n",
    "    # The remaining subsets receive no extra loader arguments\n",
    "    other_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=(first_kwargs, *other_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 160\n",
    "dsets = Datasets(torch.arange(n).float())\n",
    "dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)\n",
    "learn = synth_learner(data=dls, cbs=CollectDataCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(1)\n",
    "t = concat(*learn.collect_data.data.itemgot(0,0))\n",
    "plt.hist(t);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but_as=TfmdDL.__init__)\n",
    "@delegates()\n",
    "class PartialDL(TfmdDL):\n",
    "    \"Randomly select a subset of `partial_n` items, without replacement, at each epoch\"\n",
    "    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # A falsy `partial_n` turns subsetting off; otherwise cap it at the number of items\n",
    "        self.partial_n = None if not partial_n else min(partial_n, self.n)\n",
    "\n",
    "    def get_idxs(self):\n",
    "        \"Return `partial_n` distinct random indices, or the parent's indices when subsetting is off\"\n",
    "        if self.partial_n is not None:\n",
    "            return list(np.random.choice(self.n, self.partial_n, replace=False))\n",
    "        return super().get_idxs()\n",
    "\n",
    "    def __len__(self):\n",
    "        \"Number of batches, based on `partial_n` when it is set\"\n",
    "        if self.partial_n is None: return super().__len__()\n",
    "        full_batches,leftover = divmod(self.partial_n, self.bs)\n",
    "        # A trailing short batch only counts when it exists and is not dropped\n",
    "        return full_batches + (1 if leftover and not self.drop_last else 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):\n",
    "    \"Call `self.dataloaders` with `PartialDL`, passing `partial_n` to the first (training) subset only\"\n",
    "    first_kwargs = {'partial_n':partial_n}\n",
    "    # The remaining subsets receive no extra loader arguments\n",
    "    other_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=(first_kwargs, *other_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.partial_dataloaders(partial_n=32, bs=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(dls[0])==2\n",
    "for batch in dls[0]:\n",
    "    assert len(batch[0])==16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
